{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference for OpenSora"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define global variables. Change the following paths and the device to match your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# global variables -- edit these to match your environment\n",
    "# All paths are relative to ROOT (assumed to be the repository root as seen from this notebook's\n",
    "# directory) rather than absolute, machine-specific paths.\n",
    "ROOT = \"..\"\n",
    "cfg_path = f\"{ROOT}/configs/opensora-v1-2/inference/sample.py\"\n",
    "ckpt_path = f\"{ROOT}/outputs/207-STDiT3-XL-2/epoch0-global_step9000/\"  # diffusion-model checkpoint dir\n",
    "vae_path = f\"{ROOT}/pretrained_models/vae-pipeline\"\n",
    "save_dir = f\"{ROOT}/samples/samples_notebook/\"\n",
    "device = \"cuda:0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import the necessary libraries, parse the config, and load the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from pprint import pformat\n",
    "\n",
    "import colossalai\n",
    "import torch\n",
    "import torch.distributed as dist\n",
    "from colossalai.cluster import DistCoordinator\n",
    "from mmengine.runner import set_random_seed\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from opensora.acceleration.parallel_states import set_sequence_parallel_group\n",
    "from opensora.datasets import save_sample, is_img\n",
    "from opensora.datasets.aspect import get_image_size, get_num_frames\n",
    "from opensora.models.text_encoder.t5 import text_preprocessing\n",
    "from opensora.registry import MODELS, SCHEDULERS, build_module\n",
    "from opensora.utils.config_utils import read_config\n",
    "from opensora.utils.inference_utils import (\n",
    "    append_generated,\n",
    "    apply_mask_strategy,\n",
    "    collect_references_batch,\n",
    "    extract_json_from_prompts,\n",
    "    extract_prompts_loop,\n",
    "    get_save_path_name,\n",
    "    load_prompts,\n",
    "    prepare_multi_resolution_info,\n",
    ")\n",
    "from opensora.utils.misc import all_exists, create_logger, is_distributed, is_main_process, to_torch_dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.set_grad_enabled(False)  # inference only: no autograd graphs are needed\n",
    "\n",
    "# == parse configs ==\n",
    "cfg = read_config(cfg_path)\n",
    "cfg.model.from_pretrained = ckpt_path\n",
    "cfg.vae.from_pretrained = vae_path\n",
    "\n",
    "# == device and dtype ==\n",
    "# Validate and convert the *same* value so the assert checks the dtype that is actually used.\n",
    "# A config without `dtype` falls back to bf16.\n",
    "cfg_dtype = cfg.get(\"dtype\", \"bf16\")\n",
    "assert cfg_dtype in [\"fp16\", \"bf16\", \"fp32\"], f\"Unknown mixed precision {cfg_dtype}\"\n",
    "dtype = to_torch_dtype(cfg_dtype)\n",
    "# let float32 matmuls / cuDNN convolutions use TF32 on GPUs that support it\n",
    "torch.backends.cuda.matmul.allow_tf32 = True\n",
    "torch.backends.cudnn.allow_tf32 = True\n",
    "\n",
    "set_random_seed(seed=cfg.get(\"seed\", 1024))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# == build text-encoder and vae ==\n",
    "text_encoder = build_module(cfg.text_encoder, MODELS, device=device)\n",
    "vae = build_module(cfg.vae, MODELS)\n",
    "vae = vae.to(device, dtype).eval()\n",
    "\n",
    "# == build diffusion model ==\n",
    "# The input size is left as (None, None, None) here; `inference` computes the real\n",
    "# latent size from (num_frames, *image_size) on every call.\n",
    "input_size = (None, None, None)\n",
    "latent_size = vae.get_latent_size(input_size)\n",
    "model = build_module(\n",
    "    cfg.model,\n",
    "    MODELS,\n",
    "    input_size=latent_size,\n",
    "    in_channels=vae.out_channels,\n",
    "    caption_channels=text_encoder.output_dim,\n",
    "    model_max_length=text_encoder.model_max_length,\n",
    ")\n",
    "model = model.to(device, dtype).eval()\n",
    "\n",
    "# HACK: for classifier-free guidance, share the model's y_embedder with the text encoder\n",
    "text_encoder.y_embedder = model.y_embedder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the inference function and the display helpers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_idx = 0  # global sample counter: advanced by `inference`, reset by `reset_start_idx`\n",
    "multi_resolution = cfg.get(\"multi_resolution\", None)\n",
    "batch_size = cfg.get(\"batch_size\", 1)\n",
    "\n",
    "\n",
    "def inference(\n",
    "    prompts=cfg.get(\"prompt\", None),\n",
    "    image_size=None,\n",
    "    num_frames=None,\n",
    "    resolution=None,\n",
    "    aspect_ratio=None,\n",
    "    mask_strategy=None,\n",
    "    reference_path=None,\n",
    "    num_sampling_steps=None,\n",
    "    cfg_scale=None,\n",
    "    seed=None,\n",
    "    fps=cfg.fps,\n",
    "    num_sample=cfg.get(\"num_sample\", 1),\n",
    "    loop=cfg.get(\"loop\", 1),\n",
    "    condition_frame_length=cfg.get(\"condition_frame_length\", 5),\n",
    "    align=cfg.get(\"align\", None),\n",
    "    sample_name=cfg.get(\"sample_name\", None),\n",
    "    prompt_as_path=cfg.get(\"prompt_as_path\", False),\n",
    "    disable_progress=False,\n",
    "):\n",
    "    \"\"\"Generate samples for one or more prompts and save them under `save_dir`.\n",
    "\n",
    "    Keyword defaults are read from `cfg` once, when this cell is executed.\n",
    "\n",
    "    Args:\n",
    "        prompts: a prompt string or a list of prompt strings.\n",
    "        image_size: explicit output size; when None it is computed by `get_image_size`\n",
    "            from `resolution` and `aspect_ratio`, which are then both required.\n",
    "        num_frames: frame-count spec resolved by `get_num_frames` (the examples pass 1, \"1x\", ...).\n",
    "        resolution: resolution name, e.g. \"240p\"; only used when `image_size` is None.\n",
    "        aspect_ratio: aspect-ratio string, e.g. \"9:16\"; only used when `image_size` is None.\n",
    "        mask_strategy: per-prompt mask strategies; defaults to an empty string for every prompt.\n",
    "        reference_path: per-prompt reference paths; defaults to an empty string for every prompt.\n",
    "        num_sampling_steps: if given, overrides the scheduler's `num_sampling_steps` for this call only.\n",
    "        cfg_scale: if given, overrides the scheduler's `scale` for this call only.\n",
    "        seed: if given, the RNGs are re-seeded with `set_random_seed` before sampling.\n",
    "        fps: forwarded to `prepare_multi_resolution_info`; the saved file uses `fps // frame_interval`.\n",
    "        num_sample: number of samples generated per prompt.\n",
    "        loop: number of chained generations per sample. Clips after the first are conditioned via\n",
    "            `append_generated` and drop their first `condition_frame_length` entries along dim 1\n",
    "            (presumably the time axis) before being concatenated.\n",
    "        condition_frame_length: see `loop`.\n",
    "        align: forwarded to `apply_mask_strategy`.\n",
    "        sample_name: forwarded to `get_save_path_name`.\n",
    "        prompt_as_path: forwarded to `get_save_path_name`; when True, a batch is skipped if all of\n",
    "            its output files already exist (skipped files are not included in the returned list).\n",
    "        disable_progress: hide the per-batch progress bar.\n",
    "\n",
    "    Returns:\n",
    "        list: the paths returned by `save_sample`, one per saved sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `prompts` is None (no prompt passed and none set in the config).\n",
    "\n",
    "    Side effects:\n",
    "        creates `save_dir` and advances the global `start_idx` (passed as `sample_idx` to\n",
    "        `get_save_path_name`).\n",
    "    \"\"\"\n",
    "    global start_idx\n",
    "    from copy import deepcopy\n",
    "\n",
    "    if prompts is None:\n",
    "        raise ValueError(\"no prompts given and the config does not define `prompt`\")\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    if seed is not None:\n",
    "        set_random_seed(seed=seed)\n",
    "    if not isinstance(prompts, list):\n",
    "        prompts = [prompts]\n",
    "    if mask_strategy is None:\n",
    "        mask_strategy = [\"\"] * len(prompts)\n",
    "    if reference_path is None:\n",
    "        reference_path = [\"\"] * len(prompts)\n",
    "    # Derive the saved fps from the `fps` argument so a per-call override is honored.\n",
    "    save_fps = fps // cfg.get(\"frame_interval\", 1)\n",
    "\n",
    "    # Apply sampler overrides to a private copy: writing into the shared cfg.scheduler would make\n",
    "    # one call's settings silently persist into later calls that leave them as None.\n",
    "    scheduler_cfg = deepcopy(cfg.scheduler)\n",
    "    if num_sampling_steps is not None:\n",
    "        scheduler_cfg[\"num_sampling_steps\"] = num_sampling_steps\n",
    "    if cfg_scale is not None:\n",
    "        scheduler_cfg[\"scale\"] = cfg_scale\n",
    "    scheduler = build_module(scheduler_cfg, SCHEDULERS)\n",
    "    ret_path = []\n",
    "\n",
    "    # == prepare video size ==\n",
    "    if image_size is None:\n",
    "        assert (\n",
    "            resolution is not None and aspect_ratio is not None\n",
    "        ), \"resolution and aspect_ratio must be provided if image_size is not provided\"\n",
    "        image_size = get_image_size(resolution, aspect_ratio)\n",
    "    num_frames = get_num_frames(num_frames)\n",
    "    input_size = (num_frames, *image_size)\n",
    "    latent_size = vae.get_latent_size(input_size)\n",
    "\n",
    "    # == Iter over all samples ==\n",
    "    for batch_start in tqdm(range(0, len(prompts), batch_size), disable=disable_progress):\n",
    "        batch_end = batch_start + batch_size\n",
    "        # == prepare batch prompts ==\n",
    "        batch_prompts = prompts[batch_start:batch_end]\n",
    "        ms = mask_strategy[batch_start:batch_end]\n",
    "        refs = reference_path[batch_start:batch_end]\n",
    "\n",
    "        batch_prompts, refs, ms = extract_json_from_prompts(batch_prompts, refs, ms)\n",
    "        refs = collect_references_batch(refs, vae, image_size)\n",
    "\n",
    "        # == multi-resolution info ==\n",
    "        model_args = prepare_multi_resolution_info(\n",
    "            multi_resolution, len(batch_prompts), image_size, num_frames, fps, device, dtype\n",
    "        )\n",
    "\n",
    "        # == Iter over number of sampling for one prompt ==\n",
    "        # NOTE(review): `refs`/`ms` are rebound by `append_generated` below and carried into the\n",
    "        # next `k` iteration -- confirm this is intended when both num_sample > 1 and loop > 1.\n",
    "        for k in range(num_sample):\n",
    "            # == prepare save paths ==\n",
    "            save_paths = [\n",
    "                get_save_path_name(\n",
    "                    save_dir,\n",
    "                    sample_name=sample_name,\n",
    "                    sample_idx=start_idx + idx,\n",
    "                    prompt=batch_prompts[idx],\n",
    "                    prompt_as_path=prompt_as_path,\n",
    "                    num_sample=num_sample,\n",
    "                    k=k,\n",
    "                )\n",
    "                for idx in range(len(batch_prompts))\n",
    "            ]\n",
    "\n",
    "            # NOTE: Skip if the sample already exists\n",
    "            # This is useful for resuming sampling VBench\n",
    "            if prompt_as_path and all_exists(save_paths):\n",
    "                continue\n",
    "\n",
    "            # == Iter over loop generation ==\n",
    "            video_clips = []\n",
    "            for loop_i in range(loop):\n",
    "                batch_prompts_loop = extract_prompts_loop(batch_prompts, loop_i)\n",
    "                batch_prompts_cleaned = [text_preprocessing(prompt) for prompt in batch_prompts_loop]\n",
    "\n",
    "                # == loop ==\n",
    "                if loop_i > 0:\n",
    "                    refs, ms = append_generated(vae, video_clips[-1], refs, ms, loop_i, condition_frame_length)\n",
    "\n",
    "                # == sampling ==\n",
    "                z = torch.randn(len(batch_prompts), vae.out_channels, *latent_size, device=device, dtype=dtype)\n",
    "                masks = apply_mask_strategy(z, refs, ms, loop_i, align=align)\n",
    "                samples = scheduler.sample(\n",
    "                    model,\n",
    "                    text_encoder,\n",
    "                    z=z,\n",
    "                    prompts=batch_prompts_cleaned,\n",
    "                    device=device,\n",
    "                    additional_args=model_args,\n",
    "                    progress=False,\n",
    "                    mask=masks,\n",
    "                )\n",
    "                samples = vae.decode(samples.to(dtype), num_frames=num_frames)\n",
    "                video_clips.append(samples)\n",
    "\n",
    "            # == save samples ==\n",
    "            if is_main_process():\n",
    "                for idx, save_path in enumerate(save_paths):\n",
    "                    # Gather this sample's clip from every loop; clips after the first drop their\n",
    "                    # leading `condition_frame_length` entries along dim 1 before concatenation.\n",
    "                    # `clip_idx` is a separate name so the outer batch loop variable is not shadowed.\n",
    "                    video = [video_clips[clip_idx][idx] for clip_idx in range(loop)]\n",
    "                    for clip_idx in range(1, loop):\n",
    "                        video[clip_idx] = video[clip_idx][:, condition_frame_length:]\n",
    "                    video = torch.cat(video, dim=1)\n",
    "                    path = save_sample(\n",
    "                        video,\n",
    "                        fps=save_fps,\n",
    "                        save_path=save_path,\n",
    "                        verbose=False,\n",
    "                    )\n",
    "                    ret_path.append(path)\n",
    "        start_idx += len(batch_prompts)\n",
    "    return ret_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import Video, Image, display\n",
    "\n",
    "\n",
    "def display_results(paths):\n",
    "    \"\"\"Show each saved sample inline: image files via `Image`, everything else as an embedded `Video`.\"\"\"\n",
    "    for path in paths:\n",
    "        if is_img(path):\n",
    "            display(Image(path))\n",
    "        else:\n",
    "            display(Video(path, embed=True))\n",
    "\n",
    "\n",
    "def reset_start_idx():\n",
    "    \"\"\"Reset the global `start_idx` that `inference` uses as its running sample index.\"\"\"\n",
    "    global start_idx\n",
    "    start_idx = 0\n",
    "\n",
    "\n",
    "ALL_ASPECT_RATIO = [\"1:1\", \"16:9\", \"9:16\", \"3:4\", \"4:3\", \"1:2\", \"2:1\"]\n",
    "\n",
    "\n",
    "def inference_all_aspects(prompts, resolution, num_frames, *args, **kwargs):\n",
    "    \"\"\"Call `inference` once per aspect ratio in `ALL_ASPECT_RATIO` and collect the saved paths.\n",
    "\n",
    "    Extra positional and keyword arguments are forwarded to `inference` (positional extras follow\n",
    "    `prompts`). The per-call progress bar is hidden by default because the outer bar tracks aspect\n",
    "    ratios; pass `disable_progress=False` to show it.\n",
    "    \"\"\"\n",
    "    # setdefault rather than a hard-coded keyword, so a caller-supplied disable_progress\n",
    "    # does not raise a duplicate-keyword TypeError.\n",
    "    kwargs.setdefault(\"disable_progress\", True)\n",
    "    paths = []\n",
    "    for aspect_ratio in tqdm(ALL_ASPECT_RATIO):\n",
    "        paths.extend(\n",
    "            inference(\n",
    "                prompts,\n",
    "                *args,\n",
    "                resolution=resolution,\n",
    "                num_frames=num_frames,\n",
    "                aspect_ratio=aspect_ratio,\n",
    "                **kwargs,\n",
    "            )\n",
    "        )\n",
    "    return paths"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sampling examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Run inference on a small batch of prompts at 240p with a 1:1 aspect ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two prompts in one call, rendered at 240p with a square aspect ratio.\n",
    "paths = inference(\n",
    "    [\"a man.\", \"a woman\"],\n",
    "    num_frames=\"1x\",\n",
    "    resolution=\"240p\",\n",
    "    aspect_ratio=\"1:1\",\n",
    "    cfg_scale=7.0,\n",
    "    num_sampling_steps=30,\n",
    ")\n",
    "display_results(paths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample a single prompt at every aspect ratio in `ALL_ASPECT_RATIO`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One prompt, rendered once for each entry of ALL_ASPECT_RATIO.\n",
    "PROMPT = \"a boy.\"\n",
    "paths = inference_all_aspects(\n",
    "    PROMPT,\n",
    "    num_frames=\"1x\",\n",
    "    resolution=\"240p\",\n",
    "    cfg_scale=7.0,\n",
    "    num_sampling_steps=30,\n",
    ")\n",
    "display_results(paths)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample every resolution and video length listed in `sample_cfg`, at a fixed 9:16 aspect ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT = \"a boy.\"\n",
    "# resolution -> frame-count specs to try at that resolution\n",
    "sample_cfg = {\n",
    "    \"144p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"240p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"360p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"480p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"720p\": [1, \"1x\", \"2x\"],\n",
    "}\n",
    "all_paths = []\n",
    "for resolution, num_frames in sample_cfg.items():\n",
    "    for num_frame in num_frames:\n",
    "        print(f\"Resolution: {resolution}, Num Frames: {num_frame}\")\n",
    "        paths = inference(\n",
    "            PROMPT,\n",
    "            aspect_ratio=\"9:16\",\n",
    "            resolution=resolution,\n",
    "            num_frames=num_frame,\n",
    "            cfg_scale=7.0,\n",
    "            num_sampling_steps=30,\n",
    "            disable_progress=True,\n",
    "        )\n",
    "        display_results(paths)\n",
    "        all_paths += paths"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sample every combination of resolution, video length, and aspect ratio."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PROMPT = \"a boy.\"\n",
    "# resolution -> frame-count specs to try at that resolution\n",
    "sample_cfg = {\n",
    "    \"144p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"240p\": [1, \"1x\", \"2x\", \"4x\", \"8x\"],\n",
    "    \"360p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"480p\": [1, \"1x\", \"2x\", \"4x\"],\n",
    "    \"720p\": [1, \"1x\", \"2x\"],\n",
    "}\n",
    "all_paths = []\n",
    "for resolution, num_frames in sample_cfg.items():\n",
    "    for num_frame in num_frames:\n",
    "        print(f\"Resolution: {resolution}, Num Frames: {num_frame}\")\n",
    "        paths = inference_all_aspects(\n",
    "            PROMPT,\n",
    "            resolution=resolution,\n",
    "            num_frames=num_frame,  # the single spec for this iteration, not the whole list\n",
    "            num_sampling_steps=30,\n",
    "            cfg_scale=7.0,\n",
    "        )\n",
    "        display_results(paths)\n",
    "        all_paths.extend(paths)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "opensora",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
